library(tidyverse)
# Move up one directory so the relative data/ and results/ paths used below
# resolve. NOTE(review): assumes the script is launched from a subdirectory of
# the project root — confirm.
setwd(dir = "..")

#### Step 0: load the full cleaned corpora, intended for a class-imbalance check.
#### NOTE(review): no imbalance check is performed and eo_tmp / stwts_tmp are
#### not referenced anywhere else in this script.
eo_tmp <- read.csv(file = "data/raw/eo_clean_full.csv")
stwts_tmp <- read.csv(file = "data/raw/stwts_clean_full.csv")


### First, read in and clean the data

# Read one classifier-results file and summarise AUC over iterations for every
# (training size, representation, distance) triple.
#
# path  : results CSV whose AUC columns span random_iter1:lda100_recon_iter10
#         and are named <representation>_<distance>_iter<k>.
# sizes : training-set sizes, one per row of the file, in file order.
# level : confidence level of the normal-approximation interval.
# Returns an ungrouped tibble with V1 (training size), representation,
# distance, avg_auc, se_auc, auc_low and auc_upper.
summarise_roc <- function(path,
                          sizes = c(50, 100, 250, 500, 750, 1000,
                                    1500, 2000, 2500, 3000),
                          level = 0.95) {
  # Two-sided normal quantile: 1.96 for a 95% interval.
  z <- qnorm(1 - (1 - level) / 2)
  data.table::fread(path) %>%
    mutate(V1 = sizes) %>%
    pivot_longer(cols = random_iter1:lda100_recon_iter10,
                 values_to = "AUC") %>%
    # Random-sampling columns have no distance part. Doubling the prefix gives
    # them three "_"-separated parts like every other column. "cvec_" becomes
    # "cvec-" so the BoW variants stay a single representation token.
    mutate(name = gsub("random_", "random_random_", name),
           name = gsub("cvec_", "cvec-", name)) %>%
    separate(name, into = c("representation", "distance", "iter"),
             sep = "_") %>%
    filter(!is.na(AUC)) %>%
    group_by(V1, representation, distance) %>%
    summarize(avg_auc = mean(AUC),
              se_auc = sd(AUC) / sqrt(n()),
              auc_low = avg_auc - z * se_auc,
              auc_upper = avg_auc + z * se_auc,
              # Return ungrouped data so later select() calls can drop columns.
              .groups = "drop")
}

dat_eo <- summarise_roc("results/eo_roc_mnb.csv")
dat_stwts <- summarise_roc("results/stwts_roc_mnb.csv")


## DFs to convert representations and distances to reader-friendly format

## Lookup tables mapping raw representation / distance codes to readable
## labels. The labels are matched POSITIONALLY to the order in which each code
## first appears in dat_eo (sorted by V1, representation, distance), so they
## must stay in that order. tibble() does not recycle vectors of length > 1:
## if a method is added or removed, this fails loudly instead of silently
## mislabeling rows.
rep_df <- tibble(
  representation = unique(dat_eo$representation),
  Representation = c("BERT", "BoW + NMF", "BoW + PCA",
                     "BoW + tSNE", "BoW + UMAP",
                     "DistilBERT", "GLoVE", "LDA",
                     "Random", "RoBERTa", "Universal Encoder")
)

dist_df <- tibble(
  distance = unique(dat_eo$distance),
  Distance = c("Cosine", "d-Optimality", "K-L Divergence",
               "k-Means", "K-S Distance", "Reconstruction Loss",
               "Random")
)


## Attach the readable labels and drop the raw codes.
## ungroup() comes first because select() always keeps grouping columns.
## Without it, `representation` would survive the select() below.
dat_eo <- dat_eo %>%
  ungroup() %>%
  left_join(dist_df, by = "distance") %>%
  left_join(rep_df, by = "representation") %>%
  select(-representation, -distance)

dat_stwts <- dat_stwts %>%
  ungroup() %>%
  left_join(dist_df, by = "distance") %>%
  left_join(rep_df, by = "representation") %>%
  select(-representation, -distance)


## Appendix Figure A1: mean AUC with its interval by training-set size.
## One facet per representation and one colour per distance metric.
plot_auc_by_method <- function(dat, title) {
  ggplot(dat, aes(x = V1, y = avg_auc)) +
    geom_errorbar(aes(ymin = auc_low, ymax = auc_upper, colour = Distance),
                  width = .1, position = position_dodge(.1)) +
    geom_line(aes(x = V1, y = avg_auc, col = Distance)) +
    geom_point(position = position_dodge(.1), size = 1) +
    facet_wrap(~Representation) + theme_bw() +
    xlab("Training Set Observations") + ylab("AUC") +
    ggtitle(title)
}

# Pass each plot explicitly rather than relying on ggsave()'s last_plot().
fig_a1_eo <- plot_auc_by_method(dat_eo, "Executive Orders")
ggsave("results/figA1A_eos.pdf", plot = fig_a1_eo)

# NOTE(review): both panels are saved as figA1A_*. The StockTwits one may have
# been meant as figA1B — confirm against the manuscript before renaming.
fig_a1_stwts <- plot_auc_by_method(dat_stwts, "StockTwits")
ggsave("results/figA1A_stwts.pdf", plot = fig_a1_stwts)



## Full AUC results for both, plus SEs
## Build a wide table for Appendix Tables A1/A2.
## One row per method ("Representation, Distance") and one column per
## training-set size. Each cell is "mean AUC (SE)".
## sprintf() always prints three decimals: 0.800 rather than round()'s 0.8,
## and never scientific notation for tiny SEs.
make_auc_table <- function(dat) {
  dat %>%
    ungroup() %>%
    mutate(Method = paste(Representation, Distance, sep = ", "),
           entry = sprintf("%.3f (%.3f)", avg_auc, se_auc)) %>%
    select(V1, entry, Method) %>%
    pivot_wider(id_cols = Method, names_from = V1, values_from = entry)
}

tabA1_eo <- make_auc_table(dat_eo)
tabA2_stwts <- make_auc_table(dat_stwts)


## Write a results table to `file` as a bare LaTeX booktabs tabular.
## include.rownames = FALSE suppresses the data-frame row numbers.
## print.xtable() swallows unknown arguments through `...`, so a misspelling
## of this argument fails silently.
## The identity sanitizer leaves cell text unescaped.
write_auc_table <- function(tab, file) {
  print(xtable::xtable(tab, auto = TRUE),
        include.rownames = FALSE,
        booktabs = TRUE, floating = FALSE,
        file = file,
        sanitize.text.function = function(x) x,
        format.args = list(big.mark = ","))
}

write_auc_table(tabA1_eo, "results/tabA1_fullresults_eo.tex")
write_auc_table(tabA2_stwts, "results/tabA2_fullresults_stwts.tex")



## Rank correlation
#dat_eo$rankorder = rank(dat_eo$avg_auc)
#dat_stwts$rankorder = rank(dat_stwts$avg_auc)
#cor(dat_eo$rankorder, dat_stwts$rankorder, method="spearman")
#cor(dat_eo$avg_auc, dat_stwts$avg_auc)

## Next, run a regression
## Fit one linear model per corpus: per-iteration AUC regressed on
## representation and distance. The input is the raw results file reshaped to
## one row per (training size, method, iteration), with NA AUCs removed.
fit_auc_lm <- function(path) {
  train_sizes <- c(50, 100, 250, 500, 750, 1000,
                   1500, 2000, 2500, 3000)
  data.table::fread(path) %>%
    mutate(V1 = train_sizes) %>%
    pivot_longer(cols = random_iter1:lda100_recon_iter10,
                 values_to = "AUC") %>%
    mutate(name = gsub("random_", "random_random_", name),
           name = gsub("cvec_", "cvec-", name)) %>%
    separate(name, into = c("representation", "distance", "iter"),
             sep = "_") %>%
    filter(!is.na(AUC)) %>%
    # Uncomment to use "random" as the reference level instead of the default
    # (alphabetically first) level:
    # mutate(representation = relevel(factor(representation), "random"),
    #        distance = relevel(factor(distance), "random")) %>%
    lm(AUC ~ representation + distance, data = .)
}

m_eo <- fit_auc_lm("results/eo_roc_mnb.csv")
m_stwts <- fit_auc_lm("results/stwts_roc_mnb.csv")

## TODO: consider also fitting a pooled model across both corpora.

stargazer::stargazer(m_eo, m_stwts,
                     out = "results/tabA3_regressions.tex")




## Next, get the top performers
## Rank every (Representation, Distance) pair by its AUC summed over all
## training-set sizes. Keep the `n_top` best plus the named baselines
## (Taddy (2013) = LDA + d-Optimality, and random sampling).
## The baselines are chosen by name, not by row position, so a change in the
## ranking cannot silently swap in the wrong methods.
select_top_methods <- function(dat, n_top = 2,
                               baselines = c("LDA, d-Optimality",
                                             "Random, Random")) {
  out <- dat %>%
    group_by(Representation, Distance) %>%
    summarize(auc_sum = sum(avg_auc), .groups = "drop") %>%
    arrange(desc(auc_sum)) %>%
    mutate(repdist = paste(Representation, Distance, sep = ", ")) %>%
    filter(row_number() <= n_top | repdist %in% baselines)
  # Fails if a baseline is missing or already ranks in the top n_top.
  # The figure is meant to show n_top + length(baselines) distinct methods.
  stopifnot(nrow(out) == n_top + length(baselines))
  out
}

dat_eo2 <- select_top_methods(dat_eo)
dat_stwts2 <- select_top_methods(dat_stwts)

## Plot the 2 top performers, plus random and Taddy,
## and show that they are sig. better

## Keep only the methods listed in `keep` and give each method its own small
## horizontal offset so overlapping error bars stay readable. Then relabel the
## two baselines for the legend.
prep_top_methods <- function(dat, keep, offsets = c(-30, -15, 15, 30)) {
  out <- dat %>%
    ungroup() %>%
    mutate(repdist = paste(Representation, Distance, sep = ", ")) %>%
    filter(repdist %in% keep) %>%
    rename(Method = repdist)
  # Offsets are assigned per method, in order of first appearance, not by row
  # position. A missing (size, method) row therefore cannot shift later rows
  # onto another method's offset.
  methods <- unique(out$Method)
  stopifnot(length(methods) <= length(offsets))
  out$V1 <- out$V1 + offsets[match(out$Method, methods)]
  out$Method[out$Method == "LDA, d-Optimality"] <- "Taddy (2013)"
  out$Method[out$Method == "Random, Random"] <- "Random Sampling"
  out
}

## One Figure 1 panel: mean AUC with its interval by training-set size for
## the selected methods.
plot_top_methods <- function(dat, title) {
  ggplot(dat, aes(x = V1, y = avg_auc)) +
    geom_errorbar(aes(ymin = auc_low, ymax = auc_upper, colour = Method),
                  width = .1, position = position_dodge(.1)) +
    geom_line(aes(x = V1, y = avg_auc, col = Method)) +
    geom_point(position = position_dodge(.1), size = 1) +
    xlab("Training Set Observations") + ylab("AUC") +
    theme_bw() +
    theme(text = element_text(size = 12)) +
    ggtitle(title)
}

dat_eo3 <- prep_top_methods(dat_eo, dat_eo2$repdist)
p_a1 <- plot_top_methods(dat_eo3, "Executive Orders")
#ggsave("fig1C_eos.pdf")

dat_stwts3 <- prep_top_methods(dat_stwts, dat_stwts2$repdist)
p_a2 <- plot_top_methods(dat_stwts3, "StockTwits")
#ggsave("fig1D_stwts.pdf")

## Figure 1: stack the Executive Orders panel above the StockTwits panel and
## write the combined figure.
library(cowplot)
tmp <- plot_grid(p_a1, p_a2, nrow = 2)
save_plot(filename = "results/fig1.pdf", plot = tmp, base_height = 9)

## The component plots for the appendix

## Appendix Figure A2: marginal AUC curves, returned as a two-panel grid.
## Left panel: avg_auc averaged over distances, for each representation.
## Right panel: avg_auc averaged over representations, for each distance.
plot_component_results <- function(dat) {
  by_rep <- dat %>%
    group_by(V1, Representation) %>%
    summarize(auc_mean = mean(avg_auc), .groups = "drop") %>%
    ggplot(aes(x = V1, y = auc_mean, col = Representation)) +
    geom_line(size = 1) +
    theme_bw() + xlab("Training Samples") +
    ylab("AUC")

  by_dist <- dat %>%
    group_by(V1, Distance) %>%
    summarize(auc_mean = mean(avg_auc), .groups = "drop") %>%
    ggplot(aes(x = V1, y = auc_mean, col = Distance)) +
    geom_line(size = 1) +
    theme_bw() + xlab("Training Samples") +
    ylab("AUC")

  cowplot::plot_grid(by_rep, by_dist, ncol = 2, align = "hv",
                     rel_widths = c(4, 4))
}

# Save each grid explicitly instead of relying on ggsave()'s last_plot().
tmp1 <- plot_component_results(dat_eo)
tmp1
ggsave("results/figA2A_component_results_eo.pdf", plot = tmp1)

tmp1 <- plot_component_results(dat_stwts)
tmp1
ggsave("results/figA2B_component_results_stwts.pdf", plot = tmp1)


